# ------------------------------------------------------------------------------
# Scores fitted GBMs on the test cohort and saves their performance metrics
# (doctors' simplified risk model)
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=50000]" bash 02_predict_gbms.sh
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
set.seed(930)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(tidyverse)
library(glue)
library(Matrix)
library(xgboost)
library(testit) # assert
library(here) # here() relative filepaths
library(MLmetrics)
library(Metrics)

# Project utility module, loaded from lib/util.R under the project root.
u <- modules::use(here("lib", "util.R"))

# Working directory for this step: fitted models are read from here and the
# performance results are written back to it.
temp <- here("code", "06_physician_boundedness", "02_behavioral_gbms", "temp")

# Helpers ----------------------------------------------------------------------
# Scores one vector of predictions against both outcomes.
#
# Arguments:
#   score     Numeric predictions, one per observation.
#   stent     Outcome `stent_or_cabg_010_day`, aligned with `score`.
#   test      Outcome `test_010_day`, aligned with `score`. It is also used to
#             select the subset on which the stent/CABG outcome is evaluated.
#   which_obs Positions of the observations to evaluate (default: all).
#
# Returns a long tibble with one row per (target_name, measure_name) pair and
# columns `target_name`, `measure_name` ("auc", "r2", "logloss") and
# `performance`. For `stent_or_cabg_010_day` the measures are computed only on
# observations where `test` is true; for `test_010_day` on all of them.
measure_performance <- function(score, stent, test,
                                which_obs = seq_along(score)) {
  score <- score[which_obs]
  stent <- stent[which_obs]
  test <- test[which_obs]

  # Logical mask of tested observations. The coercion is a no-op when `test`
  # is already logical; if it is stored as 0/1 numbers, indexing with it
  # directly would select by position (dropping 0s, repeating element 1)
  # rather than by condition.
  tested <- as.logical(test)

  design <- tibble(
    target_name = c("stent_or_cabg_010_day", "test_010_day")
  )

  # One (score, target) pair of vectors per target, held in list-columns.
  performance_tb <- mutate(design,
    score = map(target_name, switch,
      stent_or_cabg_010_day = score[tested],
      test_010_day = score
    ),
    target = map(target_name, switch,
      stent_or_cabg_010_day = stent[tested],
      test_010_day = test
    )
  )

  # Metrics::auc() takes (actual, predicted) while MLmetrics::LogLoss() takes
  # (y_pred, y_true). Arguments are passed by name and the functions are
  # namespace-qualified so the two conventions cannot be mixed up.
  transmute(performance_tb,
    target_name,
    auc = map2_dbl(
      target, score,
      ~ Metrics::auc(actual = .x, predicted = .y)
    ),
    r2 = map2_dbl(
      target, score,
      ~ calibrated_r2(actual = .x, predicted = .y)
    ),
    logloss = map2_dbl(
      target, score,
      ~ MLmetrics::LogLoss(y_pred = .y, y_true = .x)
    )
  ) %>%
    # gather() is kept (rather than pivot_longer()) to preserve the row order
    # of the saved results: all auc rows, then all r2, then all logloss.
    gather(
      key = "measure_name",
      value = "performance",
      auc, r2, logloss
    )
}
# R-squared of a simple linear regression of `actual` on `predicted`.
#
# Arguments:
#   actual    Observed outcome vector.
#   predicted Model predictions, same length as `actual`.
#
# Returns a single numeric value: the `r.squared` element of the fitted
# model's summary.
calibrated_r2 <- function(actual, predicted) {
  fit <- lm(actual ~ predicted)
  fit_summary <- summary(fit)
  fit_summary$r.squared
}

# Load Data --------------------------------------------------------------------
message("Loading data...")

# Project file locations, then the cohort directory derived from them.
paths <- here("lib", "filepaths.yml") %>% read_yaml()
cohort_fp <- paths$modeling$dir %>% file.path("cohorts", "random")

# Test-set features and the matching cohort table.
x <- paths$features$test %>% readRDS()
ids <- cohort_fp %>%
  file.path("test_cohort.rds") %>%
  readRDS()

# Fitted models: the main GBMs and the "dsmall" variants.
gbm_list <- temp %>%
  file.path("gbms__stent_or_cabg_010_day__tested.rds") %>%
  readRDS()
gbm_dsmall_list <- temp %>%
  file.path("gbms__dsmall__stent_or_cabg_010_day__tested.rds") %>%
  readRDS()

# Predict for each model -------------------------------------------------------
message("Getting predictions for each model...")

# Predictions are made on every row of `x` and then subset with `keep_obs`,
# which is derived from the unfiltered `ids`. The two must therefore be
# row-aligned; fail fast if their row counts disagree.
stopifnot(nrow(x) == nrow(ids))

# Row positions of the observations not flagged for exclusion.
keep_obs <- which(!ids$exclude)
ids <- ids[keep_obs, ] %>% setDT()

# Tree counts at which each main model is evaluated: every count up to 20,
# every 5th up to 100, then every 50th from 150 up to `num_iterations`.
num_iterations <- 500
niters <- union(1:20, seq(25, 100, 5)) %>%
  union(seq(150, num_iterations, 50))

# predictions_list <- list()
performance_list <- vector("list", length(gbm_list))
performance_dsmall_list <- list()

# seq_along() rather than 1:length(): with an empty `gbm_list` the loop is
# skipped instead of iterating over c(1, 0).
for (i in seq_along(gbm_list)) {
  model_name <- glue("model_{i}")
  model <- gbm_list[[model_name]]
  if (is.null(model)) {
    stop("No entry named '", model_name, "' in gbm_list.", call. = FALSE)
  }
  message("Scores for ", model_name)

  # One prediction vector per tree count, restricted to the kept observations.
  # NOTE(review): `ntreelimit` is deprecated in newer xgboost releases in
  # favour of `iterationrange` -- confirm against the installed version.
  score <- map(
    niters,
    function(n_trees) predict(model, x, ntreelimit = n_trees)[keep_obs]
  )
  prediction_tb <- tibble(niters, score)

  # NOTE(review): the `score` list-column is carried through unnest() and so
  # ends up in the saved results; kept as-is in case downstream code reads it.
  performance_list[[i]] <- prediction_tb %>%
    mutate(
      performance = map(
        score, measure_performance,
        stent = ids$stent_or_cabg_010_day,
        test = ids$test_010_day
      )
    ) %>%
    unnest(performance)
  rm(score, prediction_tb)

  # "dsmall" variants of model i, named model_{i}_d2 .. model_{i}_d4, each
  # scored using its first tree only.
  for (j in 2:4) {
    dsmall_name <- glue("model_{i}_d{j}")
    model_dsmall <- gbm_dsmall_list[[dsmall_name]]
    if (is.null(model_dsmall)) {
      stop(
        "No entry named '", dsmall_name, "' in gbm_dsmall_list.",
        call. = FALSE
      )
    }
    score_dsmall <- predict(model_dsmall, x, ntreelimit = 1)[keep_obs]

    performance_dsmall_list[[dsmall_name]] <- measure_performance(
      score_dsmall,
      stent = ids$stent_or_cabg_010_day,
      test = ids$test_010_day
    )
  }
}

# Save -------------------------------------------------------------------------
message("Saving predictions...")
# write_rds(predictions_list, file.path(temp, "predictions_list.rds"))

# Write each result list to its own .rds file in the working directory.
walk2(
  list(performance_list, performance_dsmall_list),
  list(
    file.path(temp, "performance_list.rds"),
    file.path(temp, "performance_dsmall_list.rds")
  ),
  write_rds
)

message("Done.")
